#!/usr/bin/env python3
# A7 (REAL v4) — present-act engine with rails + per-window bootstrap
# Readout change: coherence C2 is computed on the OR of the last duty window per band (C2_window).
# We preserve C2_tick (last-tick) for transparency, but PASS uses C2_window.
#
# Control remains boolean/ordinal (no curve weights, no RNG in control).

import argparse, csv, hashlib, json, math, sys
from pathlib import Path

# ---------- utils ----------
def ensure_dir(p: Path):
    """Create directory *p* together with any missing parents; no-op if it already exists."""
    p.mkdir(exist_ok=True, parents=True)
def sha256_of_file(p: Path):
    """Return the hex SHA-256 digest of the file at *p*, reading it in 1 MiB chunks."""
    chunk_size = 1 << 20
    digest = hashlib.sha256()
    with p.open("rb") as fh:
        while True:
            block = fh.read(chunk_size)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
def write_json(p: Path, obj):
    """Write *obj* to *p* as 2-space-indented UTF-8 JSON, creating parent directories first."""
    ensure_dir(p.parent)
    payload = json.dumps(obj, indent=2)
    p.write_text(payload, encoding="utf-8")
def write_csv(p: Path, header, rows):
    """Write one *header* row followed by all *rows* to CSV file *p* (UTF-8), creating parent dirs."""
    ensure_dir(p.parent)
    with p.open("w", newline="", encoding="utf-8") as fh:
        writer = csv.writer(fh)
        writer.writerow(header)
        writer.writerows(rows)
def load_json(path: Path):
    """Parse and return the UTF-8 JSON document stored at *path*.

    Raises FileNotFoundError (message includes the path) when *path* does not exist.
    """
    if path.exists():
        with path.open(encoding="utf-8") as fh:
            return json.load(fh)
    raise FileNotFoundError(f"Missing file: {path}")

# ---------- bands / rails ----------
def ring_bounds(nx, ny, outer_margin, frac_bounds):
    """Scale fractional band bounds to absolute radii.

    The effective radius is half the shorter grid side minus *outer_margin*. If that is not
    positive, it falls back to a quarter of the longer side.

    Returns (list of (r0, r1) tuples, R_eff).
    """
    radius = min(nx, ny) / 2.0 - outer_margin
    if radius <= 0:
        radius = max(nx, ny) / 4.0
    scaled = []
    for lo, hi in frac_bounds:
        scaled.append((lo * radius, hi * radius))
    return scaled, radius

def band_id_for_cell(xc, yc, cx, cy, r_bounds):
    """Return the index of the first band whose half-open interval [r0, r1) contains the
    Euclidean distance from (cx, cy) to (xc, yc), or -1 when no band contains it."""
    dist = math.hypot(xc - cx, yc - cy)
    matches = (idx for idx, (inner, outer) in enumerate(r_bounds) if inner <= dist < outer)
    return next(matches, -1)

def rail_allows(ix, iy, t, mode, phase):
    """Return whether the rail lets cell (ix, iy) act at tick t.

    Only mode "alt2" restricts anything. There, a cell may act only on ticks whose parity
    equals the cell's checkerboard parity (ix + iy) XOR the low bit of *phase*. Every other
    mode, including "none" and unrecognised strings, always allows.
    """
    if mode != "alt2":
        return True
    cell_parity = (ix + iy) & 1
    tick_parity = t & 1
    return tick_parity == (cell_parity ^ (phase & 1))

# ---------- engine (rails + per-window bootstrap) ----------
def _neighbor_support(grid, ix, iy, nx, ny):
    """Return how many of the 4-connected neighbours of (ix, iy) are set in *grid*."""
    supp = 0
    if ix + 1 < nx: supp += grid[iy][ix + 1]
    if ix - 1 >= 0: supp += grid[iy][ix - 1]
    if iy + 1 < ny: supp += grid[iy + 1][ix]
    if iy - 1 >= 0: supp += grid[iy - 1][ix]
    return supp


def run_engine(nx, ny, H, r_bounds, controls):
    """Simulate the per-band commit engine on an nx x ny grid for H ticks.

    Each cell gets the radial band id from band_id_for_cell (-1 when outside every band).
    At tick t, a cell of band b commits (value 1) when all of the following hold:
      * its cooldown counter is 0;
      * t is inside the band's duty window, i.e. (t % period) < duty;
      * rail_allows(...) permits it for the band's rail_mode / rail_phase;
      * at least neighbor_threshold of its 4-connected neighbours committed at t-1.
    The last condition is waived for bands with neighbor_threshold > 0 on the first tick of
    each period (t % period == 0). That is the per-window bootstrap.
    A commit sets the cell's cooldown to cooldown_steps. A non-committing cell with a
    positive cooldown decrements it by one.

    Args:
        nx, ny: grid width and height in cells.
        H: number of ticks to simulate.
        r_bounds: list of (r0, r1) radial intervals, one per band (see ring_bounds).
        controls: per-band dicts with keys "period", "duty", "neighbor_threshold",
            "cooldown_steps" and optional "rail_mode" (default "alt2") and "rail_phase"
            (default 0). Cells whose band id has no entry in *controls* never commit and
            are not counted.

    Returns:
        (commit_maps, band, cells_per_band, M1, period, duty):
        commit_maps is a list of H ny x nx 0/1 grids, and band is the ny x nx band-id grid.
        cells_per_band and M1 (commits per cell per tick) have one entry per control.
        period and duty are the per-band lists taken from *controls*.
    """
    # Cell (ix, iy) covers [ix, ix+1) x [iy, iy+1), so its centre is (ix+0.5, iy+0.5) and the
    # grid centre is (nx/2, ny/2). Pairing ix+0.5 with (nx-1)/2 would shift every ring by
    # half a cell toward the origin.
    cx, cy = nx / 2.0, ny / 2.0
    band = [[band_id_for_cell(ix + 0.5, iy + 0.5, cx, cy, r_bounds) for ix in range(nx)]
            for iy in range(ny)]

    n_bands = len(controls)
    period  = [c["period"] for c in controls]
    duty    = [c["duty"] for c in controls]
    kthresh = [c["neighbor_threshold"] for c in controls]
    cool    = [c["cooldown_steps"] for c in controls]
    rmode   = [c.get("rail_mode", "alt2") for c in controls]
    rphase  = [c.get("rail_phase", 0) for c in controls]

    cells_per_band = [0] * n_bands
    for row in band:
        for b in row:
            if 0 <= b < n_bands:
                cells_per_band[b] += 1

    commits_total = [0] * n_bands
    commit_prev = [[0] * nx for _ in range(ny)]
    cooldown    = [[0] * nx for _ in range(ny)]
    commit_maps = []  # one grid per tick

    for t in range(H):
        now = [[0] * nx for _ in range(ny)]
        for iy in range(ny):
            for ix in range(nx):
                b = band[iy][ix]
                if not 0 <= b < n_bands:
                    continue  # outside every band, or a band without controls
                if cooldown[iy][ix] != 0:
                    eligible = False
                else:
                    phase_t = t % period[b]
                    eligible = phase_t < duty[b] and rail_allows(ix, iy, t, rmode[b], rphase[b])
                    # Per-window bootstrap: bands that need neighbour support skip the test
                    # on the first tick of each period, so every window can start activity.
                    bootstrap = kthresh[b] > 0 and phase_t == 0
                    if eligible and not bootstrap:
                        eligible = _neighbor_support(commit_prev, ix, iy, nx, ny) >= kthresh[b]

                if eligible:
                    now[iy][ix] = 1
                    commits_total[b] += 1
                    cooldown[iy][ix] = cool[b]
                elif cooldown[iy][ix] > 0:
                    cooldown[iy][ix] -= 1

        commit_maps.append(now)
        commit_prev = now

    # M1: commits per cell per tick (denominator floored at 1 for empty bands).
    M1 = [commits_total[b] / max(1, cells_per_band[b] * H) for b in range(n_bands)]

    return commit_maps, band, cells_per_band, M1, period, duty

# ---------- coherence (tick vs window) ----------
def neighbor_agreement(g, band, b):
    """Return the fraction of adjacent cell pairs inside band *b* whose values in *g* agree.

    Each horizontal (right) and vertical (down) neighbour pair whose two cells both belong to
    band *b* is counted once. *g* and *band* are row-major grids of the same shape.
    Returns 1.0 when there are no such pairs, including when *g* is empty.
    """
    if not g:
        return 1.0  # empty grid: no pairs, same convention as a band with no adjacent cells
    ny, nx = len(g), len(g[0])
    agree = pairs = 0
    for iy in range(ny):
        for ix in range(nx):
            if band[iy][ix] != b:
                continue
            if ix + 1 < nx and band[iy][ix + 1] == b:
                pairs += 1
                agree += int(g[iy][ix] == g[iy][ix + 1])
            if iy + 1 < ny and band[iy + 1][ix] == b:
                pairs += 1
                agree += int(g[iy][ix] == g[iy + 1][ix])
    return agree / pairs if pairs > 0 else 1.0

def window_union_map(commit_maps, period_b, duty_b):
    """
    Build an occupancy map that is the OR over the last duty window for one band.

    The last window starts at t0 = floor((H-1)/period_b)*period_b and nominally covers ticks
    t0 .. t0+duty_b-1. Any of those ticks that fall past the end of the run (t >= H) are
    replaced by the tick at the same offset in the previous window (t - period_b).
    Replacements that would be negative (run shorter than one period) are dropped, rather
    than letting a negative index wrap around to the end of the list.

    Args:
        commit_maps: list of H per-tick grids (rows of 0/1 values).
        period_b: the band's period in ticks.
        duty_b: the band's duty length in ticks.

    Returns:
        (occ, ticks). occ is a fresh grid with 1 wherever any selected tick has a commit.
        ticks lists the selected tick indices in selection order. An empty run returns ([], []).
    """
    H = len(commit_maps)
    if H == 0:
        return [], []
    ny = len(commit_maps[0])
    nx = len(commit_maps[0][0]) if ny else 0

    t0 = ((H - 1) // period_b) * period_b
    ticks = []
    for k in range(duty_b):
        t = t0 + k
        if t >= H:
            t -= period_b  # window truncated by the end of the run: borrow from previous window
        if 0 <= t < H:
            ticks.append(t)

    occ = [[0] * nx for _ in range(ny)]
    for t in ticks:
        for row_o, row_g in zip(occ, commit_maps[t]):
            for ix, v in enumerate(row_g):
                if v:
                    row_o[ix] = 1
    return occ, ticks

# ---------- R3 (immediate re-expression) ----------
def compute_R3(commit_maps, band, period, duty, kthresh, cool, rmode, rphase):
    """Per-band immediate re-expression rate R3, read from the committed maps.

    For bands 0..3, R3[b] is the fraction of (cell, tick) samples with tick >= 1 in which the
    cell committed at both tick-1 and tick. Cells with band ids outside 0..3 are ignored.
    Bands with no samples (e.g. H <= 1) get 0.0.

    The control parameters (period, duty, kthresh, cool, rmode, rphase) are accepted only for
    backward compatibility and are unused. The transitions are taken from *commit_maps*
    itself instead of re-simulating the engine, so R3 always describes the maps actually
    passed in.

    NOTE: in run_engine a committing cell's cooldown is set to cooldown_steps, which blocks it
    on the next tick. alt2 rails also allow a cell only on alternate ticks. Under either
    condition R3 is 0 for that band.
    """
    n_bands = 4  # matches the four-band panel evaluated by main()
    hits = [0] * n_bands
    trials = [0] * n_bands
    for prev, now in zip(commit_maps, commit_maps[1:]):
        for row_b, row_p, row_n in zip(band, prev, now):
            for b, was_on, is_on in zip(row_b, row_p, row_n):
                if 0 <= b < n_bands:
                    trials[b] += 1
                    if was_on == 1 and is_on == 1:
                        hits[b] += 1
    return [hits[b] / max(1, trials[b]) for b in range(n_bands)]

# ---------- main ----------
def main():
    """Run the A7 (REAL v4) diagnostic end to end.

    Command line: --manifest PATH --diag PATH --out DIR (all required).

    Inputs:
      * Manifest: grid size ("domain.grid.nx" / "ny") and tick count ("domain.ticks").
      * Diag file: ring geometry, per-band controls and tolerances.

    Per band, main computes:
      * M1: commit density.
      * C2_tick: neighbour agreement on the last tick.
      * C2_window: neighbour agreement on the OR of the band's last duty window.
      * R3: immediate re-expression rate.

    PASS requires all of:
      * M1 strictly decreasing outward;
      * C2_window non-increasing outward;
      * R3 non-decreasing outward;
      * at least min_counts_per_band commits in every band;
      * sanity_score >= sanity_min.

    Outputs:
      * <out>/metrics/band_panel.csv
      * <out>/audits/band_classifier.json
      * <out>/run_info/hashes.json
      * a one-line JSON summary on stdout.

    Raises:
        FileNotFoundError: the manifest or diag file is missing (from load_json).
        ValueError: grid size or tick count below 1, or fewer than four band controls.
    """
    n_bands = 4  # band names and ordering checks below are defined for four bands

    ap = argparse.ArgumentParser()
    ap.add_argument("--manifest", required=True)
    ap.add_argument("--diag", required=True)
    ap.add_argument("--out", required=True)
    args = ap.parse_args()

    out_dir = Path(args.out)
    metr, aud, info = out_dir/"metrics", out_dir/"audits", out_dir/"run_info"
    for d in (metr, aud, info): ensure_dir(d)

    manifest = load_json(Path(args.manifest))
    diag     = load_json(Path(args.diag))

    domain = manifest.get("domain", {})
    nx = int(domain.get("grid", {}).get("nx", 256))
    ny = int(domain.get("grid", {}).get("ny", 256))
    H  = int(domain.get("ticks", 128))
    if nx < 1 or ny < 1 or H < 1:
        # An empty grid or zero-tick run has no metrics; fail clearly instead of an IndexError.
        raise ValueError(f"domain must have nx, ny, ticks >= 1 (got nx={nx}, ny={ny}, ticks={H})")

    outer_margin = int(diag.get("ring", {}).get("outer_margin", 8))
    frac_bounds  = diag.get("bands", {}).get("frac_bounds",
                    [[0.00,0.35],[0.35,0.60],[0.60,0.85],[0.85,1.00]])
    r_bounds, R_eff = ring_bounds(nx, ny, outer_margin, frac_bounds)

    controls = diag.get("controls", {}).get("per_band", [
        {"period":6, "duty":3, "neighbor_threshold":1, "cooldown_steps":1, "rail_mode":"alt2", "rail_phase":0},
        {"period":6, "duty":2, "neighbor_threshold":1, "cooldown_steps":1, "rail_mode":"alt2", "rail_phase":0},
        {"period":8, "duty":1, "neighbor_threshold":1, "cooldown_steps":1, "rail_mode":"alt2", "rail_phase":0},
        {"period":10,"duty":1, "neighbor_threshold":0, "cooldown_steps":1, "rail_mode":"alt2", "rail_phase":0}
    ])
    if len(controls) < n_bands:
        raise ValueError(f"controls.per_band must have at least {n_bands} entries (got {len(controls)})")
    kthresh = [c["neighbor_threshold"] for c in controls]
    cool    = [c["cooldown_steps"] for c in controls]
    rmode   = [c.get("rail_mode", "alt2") for c in controls]
    rphase  = [c.get("rail_phase", 0)     for c in controls]

    # run_engine also returns the per-band period / duty lists it used.
    commit_maps, band, cells, M1, period, duty = run_engine(nx, ny, H, r_bounds, controls)

    # Expected M1: duty fraction, halved when alt2 rails admit only one parity per tick.
    M1_exp = [(duty[b] / float(period[b])) * (0.5 if rmode[b] == "alt2" else 1.0)
              for b in range(n_bands)]

    # C2_tick (last tick only; kept for transparency). H >= 1 is validated above.
    last_grid = commit_maps[-1]
    C2_tick = [neighbor_agreement(last_grid, band, b) for b in range(n_bands)]

    # C2_window (OR of the last duty window per band) -- this is what PASS uses.
    C2_window = []
    window_ticks = []
    for b in range(n_bands):
        occ, ticks = window_union_map(commit_maps, period[b], duty[b])
        window_ticks.append(ticks)
        C2_window.append(neighbor_agreement(occ, band, b))

    R3 = compute_R3(commit_maps, band, period, duty, kthresh, cool, rmode, rphase)

    # PASS rules
    tolerances = diag.get("tolerances", {})
    min_counts = int(tolerances.get("min_counts_per_band", 5000))
    sanity_min = float(tolerances.get("sanity_min", 0.90))

    # Recover integer commit counts from M1 = commits / (cells*H). round() absorbs the float
    # error of the divide/multiply round-trip, which could otherwise leave an exact count of
    # min_counts fractionally below the threshold.
    counts_ok = all(round(M1[b] * cells[b] * H) >= min_counts for b in range(n_bands))

    M1_order  = all(M1[i] > M1[i+1] for i in range(n_bands - 1))                         # strictly decreasing outward
    C2w_order = all(C2_window[i] >= C2_window[i+1] - 1e-12 for i in range(n_bands - 1))  # non-increasing outward
    R3_order  = all(R3[i] <= R3[i+1] + 1e-12 for i in range(n_bands - 1))                # non-decreasing outward

    checks = [M1_order, C2w_order, R3_order, counts_ok]
    sanity_score = sum(1 for v in checks if v) / len(checks)
    PASS = bool(all(checks) and sanity_score >= sanity_min)

    # metrics CSV
    names = ["beta_minus_2","beta_minus_1","beta_0","beta_plus_1"]
    rows = []
    for b in range(n_bands):
        rows.append([b, names[b], cells[b],
                     round(M1[b],6),
                     round(C2_tick[b],6),
                     round(C2_window[b],6),
                     round(R3[b],6),
                     round(M1_exp[b],6),
                     json.dumps(window_ticks[b])])
    write_csv(metr/"band_panel.csv",
              ["band_id","band_name","cells",
               "M1_commit_density","C2_tick","C2_window","R3_reexpression",
               "M1_expected","window_ticks"],
              rows)

    # audit JSON
    write_json(aud/"band_classifier.json", {
        "nx": nx, "ny": ny, "ticks": H, "R_eff": R_eff,
        "controls": controls,
        "metrics": {
            "M1": M1, "M1_expected": M1_exp,
            "C2_tick": C2_tick, "C2_window": C2_window,
            "R3": R3
        },
        "coherence_metric": "window_or",
        "window_ticks": window_ticks,
        "ordering_checks": {
            "M1_order": M1_order, "C2_window_order": C2w_order,
            "R3_order": R3_order, "counts_ok": counts_ok
        },
        "sanity_score": sanity_score, "sanity_min": sanity_min,
        "PASS": PASS
    })

    # provenance
    write_json(info/"hashes.json", {
        "manifest_hash": sha256_of_file(Path(args.manifest)),
        "diag_hash":     sha256_of_file(Path(args.diag)),
        "engine_entrypoint": f"python {Path(sys.argv[0]).name} --manifest <...> --diag <...> --out <...>"
    })

    # stdout summary
    print("A7 SUMMARY (REAL v4):", json.dumps({
        "M1":[round(v,6) for v in M1],
        "C2_tick":[round(v,6) for v in C2_tick],
        "C2_window":[round(v,6) for v in C2_window],
        "R3":[round(v,6) for v in R3],
        "sanity_score": round(sanity_score,3),
        "PASS": PASS,
        "audit_path": str((aud/'band_classifier.json').as_posix())
    }))

def out_dir_from_argv(argv):
    """Return the --out directory named in a raw argv list, or None if there is none.

    The failure handler uses this because it cannot rely on argparse having succeeded.
    Both ``--out DIR`` and ``--out=DIR`` are recognised. As with argparse, the last
    occurrence wins. Empty values are ignored. Abbreviated option prefixes
    (e.g. ``--ou``) are not recognised.
    """
    found = None
    for i, arg in enumerate(argv):
        if arg == "--out" and i + 1 < len(argv):
            value = argv[i + 1]
        elif arg.startswith("--out="):
            value = arg[len("--out="):]
        else:
            continue
        if value:
            found = Path(value)
    return found


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        out_dir = out_dir_from_argv(sys.argv)
        if out_dir is not None:
            try:
                ensure_dir(out_dir/"audits")
                write_json(out_dir/"audits"/"band_classifier.json",
                           {"PASS": False, "failure_reason": f"Unexpected error: {type(e).__name__}: {e}"})
            except OSError as write_err:
                # Best effort: failing to record the failure must not mask the original error.
                print(f"warning: could not write failure audit: {write_err}", file=sys.stderr)
        raise
